import hashlib
import json
import argparse
import os
from typing import List
import numpy as np
import time
import csv
import re
import statistics

# Set the thread-count environment variables of the usual numeric backends
# (OpenMP, OpenBLAS, MKL, vecLib, numexpr) to "1".
# NOTE(review): these assignments run after `import numpy` above; native
# libraries commonly read such variables when they are first loaded, so
# confirm the limits actually take effect (or set them before importing numpy).
os.environ.update({
    "OMP_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "VECLIB_MAXIMUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
})

# Module-wide instrumentation updated by hash_data():
#   hash_count            - number of hash_data() calls made so far
#   hash_time_accumulator - seconds spent inside those calls
hash_count = 0
hash_time_accumulator = 0.0

def hash_data(data: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``data`` encoded as UTF-8.

    Side effects: the wall-clock time spent encoding and hashing is added to
    the module-level ``hash_time_accumulator`` and ``hash_count`` is
    incremented by one, so callers can derive an average per-hash cost.
    """
    global hash_time_accumulator, hash_count

    began = time.perf_counter()
    digest = hashlib.sha256(data.encode("utf-8")).hexdigest()
    finished = time.perf_counter()

    hash_time_accumulator += finished - began
    hash_count += 1
    return digest

def construct_flat_merkle_tree(leaves: List[str]) -> List[str]:
    """Build a complete binary Merkle tree stored as a flat, heap-ordered list.

    The leaf list is padded up to the next power of two by repeating its last
    element. In the returned list the root is at index 0 and the children of
    node ``i`` are at ``2*i + 1`` and ``2*i + 2``; the hashed leaves occupy the
    final ``width`` slots. Every node value is produced by ``hash_data``; an
    inner node hashes the concatenation of its left and right child digests.

    Raises:
        ValueError: if ``leaves`` is empty.
    """
    count = len(leaves)
    if count == 0:
        raise ValueError("Empty leaves")

    # Smallest power of two that can hold every leaf.
    width = 1 << (count - 1).bit_length()
    first_leaf = width - 1
    padded = leaves + [leaves[-1]] * (width - count)
    nodes = [None] * (2 * width - 1)

    # Bottom level: hash each (padded) leaf into the tail of the array.
    for offset, leaf in enumerate(padded):
        nodes[first_leaf + offset] = hash_data(leaf)

    # Inner levels: walk from the last parent back to the root so both
    # children are always filled in before their parent is computed.
    for parent in reversed(range(first_leaf)):
        left_child = nodes[2 * parent + 1]
        right_child = nodes[2 * parent + 2]
        nodes[parent] = hash_data(left_child + right_child)

    return nodes

def _benchmark_file(path, filename, n, dim, repeat):
    """Time ``repeat`` Merkle-tree builds for one ``.npy`` file.

    Returns the CSV row (a dict) for the file, or ``None`` when the array
    stored at ``path`` does not have shape ``(n, dim)``.
    """
    global hash_time_accumulator, hash_count

    vectors = np.load(path)
    if vectors.shape != (n, dim):
        print(f"[!] Skipped malformed shape in {filename}: {vectors.shape}")
        return None

    # One leaf per row: the row's values rendered as comma-separated text.
    data = [",".join(map(str, vec.tolist())) for vec in vectors]

    total_time_list = []
    avg_hash_time_list = []
    hash_counts = []

    for _ in range(repeat):
        # Reset the counters maintained by hash_data() so each repetition
        # reports only its own hashing work.
        hash_time_accumulator = 0.0
        hash_count = 0

        start_time = time.perf_counter()
        construct_flat_merkle_tree(data)
        elapsed = time.perf_counter() - start_time

        avg_us = (hash_time_accumulator / hash_count * 1e6) if hash_count > 0 else 0
        total_time_list.append(elapsed)
        avg_hash_time_list.append(avg_us)
        hash_counts.append(hash_count)

    total_avg_time = statistics.mean(total_time_list)
    min_time = min(total_time_list)
    max_time = max(total_time_list)
    # stdev needs at least two samples.
    std_time = statistics.stdev(total_time_list) if repeat > 1 else 0.0

    avg_hash_time = statistics.mean(avg_hash_time_list)
    avg_hash_count = int(round(statistics.mean(hash_counts)))

    print(f"[✓] {filename} avg={total_avg_time:.6f}s, min={min_time:.6f}s, max={max_time:.6f}s")

    return {
        "n": n,
        "dim": dim,
        "total_time_sec": f"{total_avg_time:.6f}",
        "min_time_sec": f"{min_time:.6f}",
        "max_time_sec": f"{max_time:.6f}",
        "std_time_sec": f"{std_time:.6f}",
        "hash_count": avg_hash_count,
        "avg_hash_time_us": f"{avg_hash_time:.2f}",
    }


def benchmark_directory(directory: str, output_csv: str, repeat: int = 5) -> None:
    """Benchmark Merkle-tree construction for every ``.npy`` file in a directory.

    Only files named ``<n>_<dim>_<anything>.npy`` are processed; the array in
    each file must have shape ``(n, dim)``. Files with another name or shape
    are reported on stdout and skipped. Any exception raised while loading or
    benchmarking a single file is printed and that file is skipped, so one bad
    file does not abort the run. One CSV row per successfully benchmarked file
    is written to ``output_csv`` (the header is written even with no rows).

    Args:
        directory: Directory to scan (not recursive).
        output_csv: Path of the CSV file to create or overwrite.
        repeat: Number of timed tree constructions per file; must be >= 1.

    Raises:
        ValueError: if ``repeat`` is less than 1.
        OSError: if ``directory`` cannot be listed or ``output_csv`` cannot
            be written (these are not caught here).
    """
    # Without this check every file would fail inside statistics.mean([]) and
    # the run would silently end with a header-only CSV.
    if repeat < 1:
        raise ValueError(f"repeat must be at least 1, got {repeat}")

    name_pattern = re.compile(r"(\d+)_(\d+)_.*\.npy")
    results = []

    # os.listdir returns entries in arbitrary order; sort so the CSV rows
    # come out in the same order on every run and filesystem.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".npy"):
            continue

        match = name_pattern.match(filename)
        if not match:
            print(f"[!] Skipped unrecognized file name format: {filename}")
            continue

        n, dim = int(match.group(1)), int(match.group(2))
        path = os.path.join(directory, filename)

        try:
            row = _benchmark_file(path, filename, n, dim, repeat)
        except Exception as e:
            # Deliberately best-effort: report the failure and keep going.
            print(f"[!] Error processing {filename}: {e}")
            continue

        if row is not None:
            results.append(row)

    fieldnames = [
        "n", "dim", "total_time_sec",
        "min_time_sec", "max_time_sec", "std_time_sec",
        "hash_count", "avg_hash_time_us",
    ]
    with open(output_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    print(f"[✓] Benchmark results saved to {output_csv}")

def main():
    """Parse the command-line options and run the directory benchmark.

    Options: ``--input_dir`` (directory of .npy files), ``--csv_out`` (result
    CSV path) and ``--repeat`` (timed runs per file); each has a default.
    """
    cli = argparse.ArgumentParser(description="Merkle Tree Benchmark Tool")

    # (flag, type, default, help) for every supported option.
    option_specs = (
        ("--input_dir", str, "/workspace/0410_nips/2_hashtree/embds", "Directory with .npy embedding files"),
        ("--csv_out", str, "benchmark_results.csv", "Output CSV file path"),
        ("--repeat", int, 5, "Number of times to repeat for each input"),
    )
    for flag, kind, fallback, description in option_specs:
        cli.add_argument(flag, type=kind, default=fallback, help=description)

    options = cli.parse_args()
    benchmark_directory(options.input_dir, options.csv_out, options.repeat)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
